library(rjags)

#######################################################
############     DATA SPLITTING                 #######
#######################################################

# There were 17 independent samples prepared with 3 subsamples per sample, or 51 total subsamples.
# We randomly selected 1 subsample per sample for calibration routine, keeping the other 2 for testing.
# This ensures that the subsamples in the calibration routine are independent from each other.
# Subsamples were selected using a random keyboard stroke to set the seed as follows:
#
#    set.seed(547392)
#    subsamples.calibration = (0:16)*3 + sample(c(1,2,3), 17, replace=TRUE)
#
# This resulted in the 1st, 4th, 9th, ... subsamples being selected for training the calibration,
# and the other subsamples retained for (in-sample) testing of the calibration.
# The training data are stored as trimineral_training_data.txt, while
# the test data are stored as trimineral_test_data.txt.


#######################################################
############     LOAD TRAINING DATA             #######
#######################################################
### Uncomment and add directory path if not opening    
### RStudio from the directory where files are stored  
# setwd("<location>")

#### read in 'trimineral_training_data.txt', store as trimineral 
# 1st column (Sample) : 
#   Identification code, e.g. 5b refers to the subsample b of sample 5
# 2-4th columns (Aragonite, LMC, HMC): 
#   known percentages of samples
# 5-9th columns (A1, A2, C1_LMC, C2_HMC, H1): 
#   measured intensities of Aragonite (A1 primary, A2 half-height secondary),
#   low-Mg Calcite (C1_LMC), high-Mg calcite (C2_HMC), and halite (H1).
#   NA means the value was below detection levels and assumed to be 0;
#   only A1, C1_LMC, C2_HMC are used in the analysis.
    
# Read the calibration (training) subsamples; the first line of the file
# supplies the column names.
training <- read.table(file = "trimineral_training_data.txt", header = TRUE)

#######################################################
############     DATA MANIPULATION              #######
#######################################################

# 1. Overwrite every NA in the table with 0 (an NA intensity is a peak that
#    was below the detection level).
# 2. Build a 0/1 matrix flagging which entries of columns 5, 7 and 8
#    (A1, C1_LMC, C2_HMC) are different from 0.
training[is.na(training)] <- 0
nonzero <- (training[, c(5, 7, 8)] != 0) * 1

#######################################################
############     DATA FORMATTING FOR JAGS       #######
#######################################################

# Named list handed to jags.model():
#   pi      : columns 2-4 (known Aragonite, LMC, HMC composition)
#   I       : columns 5, 7, 8 (measured intensities A1, C1_LMC, C2_HMC)
#   nonzero : 0/1 indicator matrix built above
#   N       : number of rows in the training table
input_data <- list(
  pi = training[, 2:4],
  I = training[, c(5, 7, 8)],
  nonzero = nonzero,
  N = nrow(training)
)

# Examine input_data to ensure it looks as expected
input_data

#######################################################
############     INITIAL VALUE FUNCTION FOR JAGS   ####
#######################################################
# Builds the starting state for one MCMC chain: uniformly drawn starting
# values for k2, k3 and kappa, plus the JAGS random-number-generator name
# and seed.
# NOTE(review): every call returns the same .RNG.name and .RNG.seed, so all
# chains appear to share one JAGS RNG setting and differ only through their
# starting values -- confirm this is intended.
initial_values <- function() {
  # Draw in the order k2, k3, kappa so R's random stream is consumed in the
  # same sequence as the list elements are declared.
  k2_start <- runif(1, min = 2, max = 4)
  k3_start <- runif(1, min = 1, max = 2)
  kappa_start <- runif(1, min = 5, max = 9)

  list(
    k2 = k2_start,
    k3 = k3_start,
    .RNG.name = "base::Super-Duper",
    .RNG.seed = 7,
    kappa = kappa_start
  )
}

#######################################################
############     COMPILE, UPDATE, SAMPLE           ####
#######################################################

# We ran a large number of samples for numerical stability
# (4 chains by 10k adaptation, 10k update, 125k sampling).
# To test the code, it may be desirable to reduce the
# number of MCMC iterations initially.

# Compile the JAGS model with 4 chains and run the 10k-iteration adaptation
# phase; each chain's starting state comes from initial_values().
model_fit <- jags.model(
  file = "trimineral_model.txt",
  data = input_data,
  inits = initial_values,
  n.chains = 4,
  n.adapt = 10000
)

# Advance every chain by a further 10k iterations without recording samples.
update(model_fit, n.iter = 10000)

# pulling output from the model fit: record 125k iterations per chain for the
# calibration parameters, the T / T.rep quantities and the Bayes.p nodes
model_output <- coda.samples(
  model = model_fit,
  variable.names = c(
    "k2", "k3", "kappa",
    "T1", "T1.rep", "T2", "T2.rep", "T3", "T3.rep",
    "Bayes.p1", "Bayes.p2", "Bayes.p3"
  ),
  n.iter = 125000
)

#######################################################
############     CALIBRATION RESULTS               ####
#######################################################

# printing output from model
# The summary is computed once and reused below: summarising every chain is
# expensive, and repeating it only reproduces the same numbers.
model_summary <- summary(model_output)
model_summary

# calculate values for Excel spreadsheet
# Posterior draws and summary rows are selected by monitored-node name rather
# than by position, so the results do not depend on the order in which the
# monitored nodes happen to be stored.
tmp <- as.matrix(model_output)
k.cov <- cov(tmp[, c("k2", "k3")])  # posterior covariance matrix of (k2, k3)
var.k2 <- k.cov["k2", "k2"]
var.k3 <- k.cov["k3", "k3"]
cov.k2k3 <- k.cov["k2", "k3"]
cov.ests <- c(var.k2, var.k3, cov.k2k3)
names(cov.ests) <- c("var.k2", "var.k3", "cov.k2k3")
# Posterior medians (50% quantile) of k2, k3 and kappa
ests <- model_summary$quantiles[c("k2", "k3", "kappa"), "50%"]
# Posterior means of the Bayes.p nodes
Bayes.p <- model_summary$statistics[c("Bayes.p1", "Bayes.p2", "Bayes.p3"), "Mean"]

# Print output required for Excel calculations
ests
cov.ests


# Print Bayesian p-values for Excel notes
Bayes.p

# Numerical diagnostics to ensure convergence
gelman.diag(model_output)
effectiveSize(model_output)



#######################################################
############     ESTIMATION FOR NEW SAMPLES in R   ####
#######################################################

# (A) Define the function that estimates compositions
#     and uses the multivariate delta method to estimate uncertainty
#     as described in Appendix A
# (B) Estimate composition for the withheld samples in trimineral_test_data.txt
# (C) Estimate composition for the training data in trimineral_training_data.txt
#     (but noting this data was used to develop the model)

###
# (A) This function implements equations presented in Appendix A
# and is meant to mirror the Excel spreadsheet (trimineral_estimation.xlsx)
# that is provided in the data repository.
#
# With k1 fixed at 1, the proportion of mineral j is estimated as
#   pihat_j = (I_j / k_j) / (I_1 / k_1 + I_2 / k_2 + I_3 / k_3)
# and two variance components are computed for every estimate:
#   intrinsic   : depends on kappa, the intensities and the estimated
#                 proportions
#   calibration : delta-method propagation of var.k2, var.k3 and cov.k2k3
#                 through the partial derivatives of pihat_j w.r.t. k2, k3
# Both components are set to 0 when an estimated proportion is exactly 0 or 1.
#
# NOTE(review): the intrinsic variances of minerals 2 and 3 use the same
# expression as mineral 1 with the mineral indices permuted; confirm that the
# Excel spreadsheet agrees.
#
# Inputs:
#   ests : a vector of length 3 with estimated values for k2, k3, kappa
#   cov.ests : a vector of length 3 with estimates for var.k2, var.k3, cov.k2k3
#   I : a matrix or data frame whose first three columns hold the measured
#       intensities (Aragonite primary, Low-Mg Calcite, High-Mg Calcite),
#       with n rows, where n is the number of samples; NAs are treated as 0
# Outputs:
#   a list with four n x 3 data frames, rows representing each sample and
#   columns ordered (Aragonite, Low-Mg Calcite, High-Mg Calcite):
#     pihat          : estimated proportions
#     intrinsic.se   : square root of the intrinsic variance
#     calibration.se : square root of the calibration variance
#     combined.se    : sqrt(intrinsic.se^2 + calibration.se^2)
#   A row whose three intensities are all 0 has no defined composition and
#   yields NaN proportions with NA standard errors.
###
composition_estimator <- function(ests, cov.ests, I) {
  stopifnot(
    length(ests) >= 3,
    length(cov.ests) >= 3,
    length(dim(I)) == 2,
    ncol(I) >= 3
  )

  # Strip names so that they cannot leak into the row names of the output.
  ests <- unname(ests)
  cov.ests <- unname(cov.ests)
  k1 <- 1
  k2 <- ests[1]
  k3 <- ests[2]
  kappa <- ests[3]
  var.k2 <- cov.ests[1]
  var.k3 <- cov.ests[2]
  cov.k2k3 <- cov.ests[3]

  I <- as.matrix(I[, 1:3, drop = FALSE])
  I[is.na(I)] <- 0 # Convert NAs to 0
  I1 <- unname(I[, 1])
  I2 <- unname(I[, 2])
  I3 <- unname(I[, 3])

  # Estimated proportions (all samples at once)
  denom <- k2 * k3 * I1 + k1 * k3 * I2 + k1 * k2 * I3
  pihat1 <- k2 * k3 * I1 / denom
  pihat2 <- k1 * k3 * I2 / denom
  pihat3 <- k1 * k2 * I3 / denom

  # Intrinsic variance of one mineral's proportion `p` (constant k.own,
  # intensity `intensity`) given the other two minerals (p.a, k.a), (p.b, k.b).
  # A single shared formula keeps the three minerals consistent.
  intrinsic_var <- function(p, intensity, k.own, p.a, k.a, p.b, k.b) {
    phi2 <- kappa * p * (1 - p) / intensity
    v <- phi2 *
      (k.a * k.b * p * (1 - p)^2 + k.own * p^2 * (k.b * p.a + k.a * p.b)) /
      (k.a * k.b * (1 - p))
    # The formula is 0/0 at p = 0 or p = 1; the variance is defined as 0 there.
    ifelse(p == 0 | p == 1, 0, v)
  }

  # Calibration variance from the partial derivatives w.r.t. k2 and k3
  # (k1 is fixed, so it contributes no uncertainty).
  calibration_var <- function(p, dp.dk2, dp.dk3) {
    v <- var.k2 * dp.dk2^2 + var.k3 * dp.dk3^2 + 2 * cov.k2k3 * dp.dk2 * dp.dk3
    ifelse(p == 0 | p == 1, 0, v)
  }

  i.var1 <- intrinsic_var(pihat1, I1, k1, pihat2, k2, pihat3, k3)
  i.var2 <- intrinsic_var(pihat2, I2, k2, pihat1, k1, pihat3, k3)
  i.var3 <- intrinsic_var(pihat3, I3, k3, pihat1, k1, pihat2, k2)

  # partials of each estimated proportion with respect to k2 and k3
  denom2 <- denom^2
  dp1.dk2 <- k1 * k3^2 * I1 * I2 / denom2
  dp1.dk3 <- k1 * k2^2 * I1 * I3 / denom2
  dp2.dk2 <- -k1 * k3 * I2 * (k3 * I1 + k1 * I3) / denom2
  dp2.dk3 <- k2 * k1^2 * I2 * I3 / denom2
  dp3.dk2 <- k3 * k1^2 * I3 * I2 / denom2
  dp3.dk3 <- -k1 * k2 * I3 * (k2 * I1 + k1 * I2) / denom2

  c.var1 <- calibration_var(pihat1, dp1.dk2, dp1.dk3)
  c.var2 <- calibration_var(pihat2, dp2.dk2, dp2.dk3)
  c.var3 <- calibration_var(pihat3, dp3.dk2, dp3.dk3)

  pihat <- data.frame(pihat1 = pihat1, pihat2 = pihat2, pihat3 = pihat3)
  i.se <- data.frame(i.se1 = sqrt(i.var1), i.se2 = sqrt(i.var2), i.se3 = sqrt(i.var3))
  c.se <- data.frame(c.se1 = sqrt(c.var1), c.se2 = sqrt(c.var2), c.se3 = sqrt(c.var3))
  combined.se <- sqrt(i.se^2 + c.se^2)
  list(pihat = pihat, intrinsic.se = i.se, calibration.se = c.se, combined.se = combined.se)
}

###
# Load the withheld test data and keep the measured intensities from the
# 5th, 7th, 8th columns (Arag, low-Mg calcite, high-Mg calcite)
###
test <- read.table(file = "trimineral_test_data.txt", header = TRUE)
test.I <- test[, c(5, 7, 8)]

###
# (B) ESTIMATING COMPOSITION OF THE TEST DATA
#
# 1. Note, each run of the calibration model will produce slightly different
#    parameter estimates due to Monte Carlo error. MC error reduces as the
#    number of iterations increases, but never goes away completely.
# 2. Analysis is therefore run twice: first with the values from the specific
#    calibration results of the previous section, then with the values as
#    entered in the Excel spreadsheet.
composition_estimator(ests = ests, cov.ests = cov.ests, I = test.I)
composition_estimator(
  ests = c(3.50, 1.49, 43.33),
  cov.ests = c(0.0958, 0.0232, 0.0244),
  I = test.I
)

###
# (C) ESTIMATING COMPOSITION OF THE TRAINING DATA
# (these subsamples were used to fit the calibration itself)
training.I <- training[, c(5, 7, 8)]
composition_estimator(ests = ests, cov.ests = cov.ests, I = training.I)
